#define ARMA_DONT_PRINT_ERRORS

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

#include <algorithm>
#include <cstdint>
#include <exception>
#include <functional>
#include <random>
#include <string>
#include <vector>

using namespace Rcpp;
using namespace arma;



//[[Rcpp:plugins(cpp11)]]
#include <omp.h>
//[[Rcpp:plugins(openmp)]]


// log(2 * pi): the per-dimension normalising term of the Gaussian log-density,
// used by dmvnrm_arma.
const double log2pi{std::log(2.0 * M_PI)};




// Return the first element (linear index 0) of `rhs` as a plain double.
// Typically used to pull the scalar out of a 1-element result such as a
// single-row dmvnrm_arma() call. Transposing does not move element 0, so the
// value is read directly instead of materialising rhs.t() first.
// Throws (Armadillo bounds check, unless ARMA_NO_DEBUG is defined) if rhs is empty.
double am_mult2(const arma::mat& rhs)
{
  return rhs(0);
}


// Draw `n` samples from the multivariate normal N(mu, sigma).
//
// n     : number of draws.
// mu    : mean, expected as a k x 1 column where k = sigma.n_cols.
// sigma : k x k covariance; arma::chol() must succeed on it (it throws
//         std::runtime_error when the decomposition fails).
// rn    : source of standard-normal deviates, called once per element of an
//         n x k matrix. The default, R's norm_rand(), uses R's RNG, and the R
//         API is not thread-safe, so pass a thread-local generator when calling
//         from OpenMP threads.
// Returns a k x n matrix whose columns are the draws.
arma::mat mvrnormArma(int n, const arma::mat& mu, const arma::mat& sigma,
                      std::function<double()> rn = norm_rand) {
  const arma::uword ncols = sigma.n_cols;
  arma::mat Y(n, ncols);
  Y.imbue(rn);
  // With R = chol(sigma) (upper triangular, R'R = sigma) each row of Y * R has
  // covariance sigma. Shifting every row by mu' in place avoids building the
  // k x n repmat(mu) temporary and its transpose.
  arma::mat Z = Y * arma::chol(sigma);
  Z.each_row() += mu.t();
  return Z.t();
}


// Convenience wrapper around mvrnormArma(): `n` draws from N(mu, sigma) using
// R's norm_rand() as the standard-normal source. The R API is not thread-safe,
// so call this only from R's main thread.
arma::mat defaultRNG(int n, arma::mat mu, arma::mat sigma) {
  const std::function<double()> r_normal = norm_rand;
  arma::mat draws = mvrnormArma(n, mu, sigma, r_normal);
  return draws;
}


// Draw `n` samples from N(mu, sigma) using a Mersenne Twister created for this
// call. The engine is seeded with s(t), where t is the calling OpenMP thread's
// number, so `s` needs at least as many entries as the thread team. Seeds are
// truncated to integers.
arma::mat serial(int n, const arma::mat& mu, const arma::mat& sigma, const arma::vec& s) {
  // s(...) is bounds-checked (unless ARMA_NO_DEBUG is defined), unlike s[...].
  // A seed vector shorter than the thread team now raises an Armadillo error
  // instead of reading past the end of the vector.
  const double seed = s(omp_get_thread_num());
  std::mt19937 engine(static_cast<uint64_t>(seed));
  std::normal_distribution<double> nor(0, 1);
  return mvrnormArma(n, mu, sigma, [&]() { return nor(engine); });
}





// Multivariate normal density of each row of `x` under N(mean, sigma).
//
// x     : n x d matrix, one observation per row.
// mean  : 1 x d mean vector.
// sigma : d x d covariance; arma::chol() must succeed on it.
// logd  : if true return log-densities, otherwise densities.
// Returns an n x 1 vector holding one (log-)density per row of x.
// Inputs are taken by const reference; they were previously copied on every call.
arma::vec dmvnrm_arma(const arma::mat& x,
                      const arma::rowvec& mean,
                      const arma::mat& sigma,
                      bool logd = false) {
  const arma::uword n = x.n_rows;
  const double xdim = static_cast<double>(x.n_cols);

  // rooti is the inverse of the lower Cholesky factor L (LL' = sigma). Then
  // z = rooti * (x_i - mean)' satisfies z'z = squared Mahalanobis distance, and
  // sum(log(diag(rooti))) = -0.5 * log(det(sigma)).
  const arma::mat rooti = arma::trans(arma::inv(arma::trimatu(arma::chol(sigma))));
  const double rootisum = arma::sum(arma::log(rooti.diag()));
  const double constants = -(xdim / 2.0) * log2pi;

  arma::vec out(n);
  for (arma::uword i = 0; i < n; ++i) {
    const arma::vec z = rooti * arma::trans(x.row(i) - mean);
    out(i) = constants - 0.5 * arma::sum(z % z) + rootisum;
  }

  if (!logd) {
    out = arma::exp(out);
  }
  return out;
}







// Return the first element (linear index 0) of `rhs` as a plain double.
// Typically used to turn a 1 x 1 Armadillo result (e.g. a column sum) into a
// scalar. Reads the element directly instead of copying the whole matrix first.
// Throws (Armadillo bounds check, unless ARMA_NO_DEBUG is defined) if rhs is empty.
double vm_convert(const arma::mat& rhs)
{
  return rhs(0);
}



// [[Rcpp::export()]]
Rcpp::List mcmc_logit_index (arma::mat y,
arma::mat X,
arma::mat W,
arma::mat I_party,
arma::mat I_country,
arma::mat I_duration,
arma::mat I_index_duration,
arma::colvec table_duration,
arma::mat I_colindex_party,
arma::colvec table_duration_colparty,
arma::mat I_colindex_country,
arma::colvec table_duration_colcountry,
arma::mat I_outcome,
arma::mat Cov_Beta,
arma::mat alphastart,
arma::mat beta2start,
arma::mat beta3start,
arma::mat reffect_partystart,
arma::mat reffect_countrystart,
arma::rowvec densitybeta,
arma::mat DiagWish, 
arma::uvec icol0,
arma::uvec icol1,
arma::uvec icol2,
int mcmc = 100,
int burn=10,
int thin=10,
int chains=3
 ) {

 // Dimensions
int n_duration= I_duration.n_cols ;
int nx= X.n_cols ;
int N=X.n_rows;
int r=W.n_cols;
int n_country=I_country.n_cols;
int n_party=I_party.n_cols;


 // Current Containers
 arma::mat alpha = alphastart ;
  arma::mat alphanew(nx*2,1);
alphanew.fill(0.0) ;




arma::vec seeds = linspace(0, 2, chains);


  arma::mat beta2 = beta2start ;
  arma::mat beta3 = beta3start ;
  arma::mat beta2new(r,1);
  beta2new.fill(0.0);

arma::mat chucha(r,1);
chucha.fill(0.0);


arma::mat chucha2(r,1);
chucha2.fill(0.0);

  arma::mat beta3new(r,1);
  beta3new.fill(0.0);

 arma::mat reffect_party = reffect_partystart ;
 arma::mat reffect_country = reffect_countrystart ;
 



arma::mat etaold_alpha(N, 3) ;
 etaold_alpha.fill(0.0) ;
 arma::mat etanew_alpha(N, 3) ;
 etanew_alpha.fill(0.0) ;


 arma::mat etaold_beta(N, 3) ;
 etaold_beta.fill(0.0) ;
 arma::mat etanew_beta(N, 3) ;
 etanew_beta.fill(0.0) ;




 arma::mat etaold_country(N, 3) ;
 etaold_country.fill(0.0) ;
 arma::mat etanew_country(N, 3) ;
 etanew_country.fill(0.0) ;
 arma::mat lcountry(n_country, 1) ;
 lcountry.fill(0.0) ;


 arma::mat etaold_party(N, 3) ;
 etaold_party.fill(0.0) ;
 arma::mat etanew_party(N, 3) ;
 etanew_party.fill(0.0) ;
 arma::mat lparty(n_party, 1) ;
 lparty.fill(0.0) ;





arma::mat dif_alpha(2*nx, n_duration-1);

 arma::mat Q(2*nx,2*nx,arma::fill::eye);
 int nk=Q.n_cols;





arma::mat sigmabetadensity(r,r,arma::fill::eye);

arma::mat reffect_country_new(n_country,2);
arma::mat reffect_party_new(n_party,2);

 arma::mat mu0(1, 2) ;
 mu0.fill(0.0) ;

double var=1.0;
arma::vec varscountry(2) ;
varscountry.fill(var) ;
 arma::mat sigma2country(2, 2) ;
 sigma2country.fill(0.0) ;
 sigma2country.diag() = varscountry ;

arma::vec varsparty(2) ;
varsparty.fill(var) ;
 arma::mat sigma2party(2, 2) ;
 sigma2party.fill(0.0) ;
 sigma2party.diag() = varsparty ;


 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracechucha(r, lastit,chains) ;
 arma::cube tracechucha2(r, lastit,chains) ;
  arma::cube tracealpha(nk*n_duration, lastit,chains) ;
 arma::cube traceQ(nk*nk, lastit,chains) ;
  arma::cube tracebeta2(r, lastit,chains) ;
  arma::cube tracebeta3(r, lastit,chains) ;
 arma::cube tracesigma2country(4, lastit,chains) ;
 arma::cube tracereffect_country(n_country*2, lastit,chains) ;
 arma::cube tracesigma2party(4, lastit,chains) ;
 arma::cube tracereffect_party(n_party*2, lastit,chains) ;



 
#pragma omp parallel num_threads(chains)
  {


std::mt19937 engine(
static_cast<uint64_t>(seeds[omp_get_thread_num()] ) );
    std::uniform_real_distribution<double> uni(0.0, 1.0);
     std::normal_distribution<double> nor(0, 1);

  std::random_device rd;
    std::mt19937 gen(rd());

int g1 = 4;
    double g2 = 0.5;
    std::gamma_distribution<> gam(g1, 1.0 / g2);



#pragma omp for 
for (int chain = 0; chain < chains; ++chain) {

 // START SAMPLING
for (int iter = 0 ; iter < mcmc ; ++iter) {


// DURATION PARAMETERS
for (int n = 0 ; n <n_duration; ++n) {


arma::uvec rowindice =  arma::conv_to<arma::uvec>::from(I_index_duration.rows(0,table_duration(n)).col(n));
arma::uvec colparty= arma::conv_to<arma::uvec>::from(I_colindex_party.rows(0,table_duration_colparty(n)).col(n));
arma::uvec colcountry= arma::conv_to<arma::uvec>::from(I_colindex_country.rows(0,table_duration_colcountry(n)).col(n));


etaold_alpha.submat(rowindice,icol0)=1/(1+exp(X.rows(rowindice)*alpha.rows(0,nx-1).col(n)+W.rows(rowindice)*beta2+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol0)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol0))+
exp(X.rows(rowindice)*alpha.rows(nx, 2*nx-1).col(n)+W.rows(rowindice)*beta3+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol1)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol1)));
etaold_alpha.submat(rowindice,icol1)=exp(X.rows(rowindice)*alpha.rows(0,nx-1).col(n)+W.rows(rowindice)*beta2+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol0)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol0))%etaold_alpha.submat(rowindice,icol0);
etaold_alpha.submat(rowindice,icol2)=exp(X.rows(rowindice)*alpha.rows(nx,2*nx-1).col(n)+W.rows(rowindice)*beta3+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol1)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol1))%etaold_alpha.submat(rowindice,icol0);



if (n == 0) {
alphanew=mvrnormArma(1,  alpha.col(n+1), Q, [&](){return nor(engine);});
} else if (n>0 && n<(n_duration-1)) {
alphanew=mvrnormArma(1,  0.5*alpha.col(n-1)+0.5*alpha.col(n+1), 0.5*Q, [&](){return nor(engine);});
} else if (n==(n_duration-1)) {
alphanew=mvrnormArma(1,  alpha.col(n-1), Q, [&](){return nor(engine);});
}


etanew_alpha.submat(rowindice,icol0)=1/(1+exp(X.rows(rowindice)*alphanew.rows(0,nx-1)+W.rows(rowindice)*beta2+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol0)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol0))+
exp(X.rows(rowindice)*alphanew.rows(nx, 2*nx-1)+W.rows(rowindice)*beta3+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol1)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol1)));
etanew_alpha.submat(rowindice,icol1)=exp(X.rows(rowindice)*alphanew.rows(0,nx-1)+W.rows(rowindice)*beta2+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol0)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol0))%etanew_alpha.submat(rowindice,icol0);
etanew_alpha.submat(rowindice,icol2)=exp(X.rows(rowindice)*alphanew.rows(nx,2*nx-1)+W.rows(rowindice)*beta3+I_party.submat(rowindice, colparty)*reffect_party.submat(colparty, icol1)+I_country.submat(rowindice, colcountry)*reffect_country.submat(colcountry, icol1))%etanew_alpha.submat(rowindice,icol0);


double accept_alpha=vm_convert(sum(I_outcome.submat(rowindice, icol0)%log(etanew_alpha.submat(rowindice,icol0)),0)+sum(I_outcome.submat(rowindice, icol1)%log(etanew_alpha.submat(rowindice,icol1)),0)+sum(I_outcome.submat(rowindice, icol2)%log(etanew_alpha.submat(rowindice,icol2)),0)-
(sum(I_outcome.submat(rowindice, icol0)%log(etaold_alpha.submat(rowindice,icol0)),0)+sum(I_outcome.submat(rowindice, icol1)%log(etaold_alpha.submat(rowindice,icol1)),0)+sum(I_outcome.submat(rowindice, icol2)%log(etaold_alpha.submat(rowindice,icol2)),0)));


if(log(uni(engine))<accept_alpha) {
for (int l=0; l<nk; ++l) {
alpha(l,n)=alphanew(l);
}
}




}



for (int k = 0 ; k < nk ; ++k) {
for (int s=1; s<n_duration; ++s) {
dif_alpha(k,s-1)=alpha(k,s)-alpha(k,s-1);
}

g1=1.0+(n_duration-1)/2;
g2=(1.0+vm_convert((dif_alpha.col(k).t()*dif_alpha.col(k)))/2);
Q(k,k)=1.0/gam(engine);


}





etaold_beta.col(0)=1/(1+exp(sum(I_duration%(X*alpha.rows(0,nx-1)),1)+W*beta2+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))+
exp(sum(I_duration%(X*alpha.rows(nx,2*nx-1)),1)+W*beta3+I_party*reffect_party.col(1)+I_country*reffect_country.col(1)));
etaold_beta.col(1)=exp(sum(I_duration%(X*alpha.rows(0,nx-1)),1)+W*beta2+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))%etaold_beta.col(0);
etaold_beta.col(2)=exp(sum(I_duration%(X*alpha.rows(nx,2*nx-1)),1)+W*beta3+I_party*reffect_party.col(1)+I_country*reffect_country.col(1))%etaold_beta.col(0);



beta2new=mvrnormArma(1, beta2, Cov_Beta, [&](){return nor(engine);});
beta3new=mvrnormArma(1, beta3, Cov_Beta, [&](){return nor(engine);});


etanew_beta.col(0)=1/(1+exp(sum(I_duration%(X*alpha.rows(0,nx-1)),1)+W*beta2new+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))+
exp(sum(I_duration%(X*alpha.rows(nx,2*nx-1)),1)+W*beta3new+I_party*reffect_party.col(1)+I_country*reffect_country.col(1)));
etanew_beta.col(1)=exp(sum(I_duration%(X*alpha.rows(0,nx-1)),1)+W*beta2new+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))%etanew_beta.col(0);
etanew_beta.col(2)=exp(sum(I_duration%(X*alpha.rows(nx,2*nx-1)),1)+W*beta3new+I_party*reffect_party.col(1)+I_country*reffect_country.col(1))%etanew_beta.col(0);

double accept_beta=vm_convert((sum(I_outcome.col(0)%log(etanew_beta.col(0)),0)+sum(I_outcome.col(1)%log(etanew_beta.col(1)),0)+sum(I_outcome.col(2)%log(etanew_beta.col(2)),0))-
			(sum(I_outcome.col(0)%log(etaold_beta.col(0)),0)+sum(I_outcome.col(1)%log(etaold_beta.col(1)),0)+sum(I_outcome.col(2)%log(etaold_beta.col(2)),0)))+
(am_mult2(dmvnrm_arma(beta2new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta2.t(),densitybeta, sigmabetadensity, true)))+
(am_mult2(dmvnrm_arma(beta3new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta3.t(),densitybeta, sigmabetadensity, true)));


if(log(uni(engine))<accept_beta) {
for (int l=0; l<r; ++l) {
beta2(l)=beta2new(l);
beta3(l)=beta3new(l);
}
}



 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracealpha.subcube(0, j, chain, nk*n_duration-1, j, chain)=vectorise(alpha);
traceQ.subcube(0,j,chain, nk*nk-1, j, chain)=vectorise(Q);
tracebeta2.subcube(0, j, chain, (r-1), j, chain)=beta2;
tracebeta3.subcube(0, j, chain, (r-1), j, chain)=beta3;
tracereffect_country.subcube(0, j, chain, 2*n_country-1, j, chain)=vectorise(reffect_country);
tracesigma2country.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2country);
tracereffect_party.subcube(0, j, chain, 2*n_party-1, j, chain)=vectorise(reffect_party);
tracesigma2party.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2party);

} //end trace
}

}  // end iterations
} // end chains

// Returns
Rcpp::List ret;
ret["alpha"] = tracealpha ;
 ret["beta2"] = tracebeta2 ;
 ret["beta3"] = tracebeta3 ;
 ret["Q"] = traceQ ;
ret["reffect_country"] = tracereffect_country;
ret["sigma2country"]=tracesigma2country;
ret["reffect_party"] = tracereffect_party;
ret["sigma2party"]=tracesigma2party;
 

 return(ret) ;
 

}

